#include "esn.h"

// Reservoir nonlinearity: overwrites every element of v, in place, with its
// hyperbolic tangent.
void esn_network::activationFunction(gsl_vector *v){
	const size_t count = v->size;
	for(size_t idx = 0; idx != count; ++idx){
		const double pre = gsl_vector_get(v, idx);
		gsl_vector_set(v, idx, tanh(pre));
	}
}
// Inverse of activationFunction: overwrites every element of v, in place, with
// gsl_atanh of that element. Values outside (-1, 1) have no finite inverse
// tangent; what gsl_atanh yields for them is left to GSL.
void esn_network::activationFunction_inv(gsl_vector *v){
	const size_t count = v->size;
	for(size_t idx = 0; idx != count; ++idx){
		const double activated = gsl_vector_get(v, idx);
		gsl_vector_set(v, idx, gsl_atanh(activated));
	}
}
// Runs the network for `steps` time steps, starting from the current state (x, y),
// and returns the outputs recorded before each update.
//
// input         - matrix with nInputs rows and at least `steps` columns (one sample per
//                 column); may be NULL when nInputs == 0.
// desired       - matrix with nOutputs rows (one teacher output per column). When
//                 useFeedback is set, column i is fed back instead of y for every
//                 i < forcing_steps (teacher forcing), so it then needs at least
//                 min(steps, forcing_steps) columns. May be NULL when no step is forced.
// steps         - number of time steps to simulate; must be positive.
// forcing_steps - number of initial steps whose feedback comes from `desired`.
//
// Returns a newly allocated nOutputs x steps matrix (caller frees) whose column i is
// y before update i, or NULL on invalid arguments. The members x and y are updated in
// place, and the collected reservoir states are written to ./lang/states.txt.
gsl_matrix *esn_network::runNetwork(gsl_matrix *input, gsl_matrix *desired, int steps, int forcing_steps){
	// input and desired store samples in columns
	if(steps <= 0){
		printf("WARNING: Neplatny pocet krokov.\n");
		return NULL;
	}
	if((input == NULL) && (nInputs > 0)){
		printf("WARNING: Prazdny vstup.");
		return NULL;
	}
	if((input != NULL) && (input->size1 != nInputs)){
		printf("WARNING: Neplatny rozmer vstupu.");
		return NULL;
	}
	// Every step reads input column i, so `steps` columns must exist.
	if((nInputs > 0) && (input->size2 < (size_t)steps)){
		printf("WARNING: Neplatny rozmer vstupu.");
		return NULL;
	}
	const bool teacherForced = useFeedback && (forcing_steps > 0);
	if(desired == NULL){
		// desired is only read for teacher-forced feedback; otherwise NULL is fine.
		if(teacherForced){
			printf("WARNING: Prazdny vystup.\n");
			return NULL;
		}
	}else{
		if(desired->size1 != nOutputs){
			printf("WARNING: Neplatny rozmer vstupu.");
			return NULL;
		}
		if(teacherForced){
			const size_t forcedCols = (size_t)((forcing_steps < steps) ? forcing_steps : steps);
			if(desired->size2 < forcedCols){
				printf("WARNING: Neplatny rozmer vstupu.");
				return NULL;
			}
		}
	}

	gsl_matrix *output = gsl_matrix_calloc(nOutputs, steps);
	gsl_matrix *states = gsl_matrix_calloc(nResUnits, steps);

	// Scratch vectors are allocated once instead of on every time step.
	gsl_vector *x_prev = gsl_vector_alloc(nResUnits);
	gsl_vector *fb = useFeedback ? gsl_vector_alloc(nOutputs) : NULL;
	gsl_vector *u = (nInputs > 0) ? gsl_vector_alloc(nInputs) : NULL;

	for(int i = 0; i < steps; i++){
		// record output y and state x before the update
		gsl_matrix_set_col(output, i, y);
		gsl_matrix_set_col(states, i, x);

		// x <- W * x ; BLAS dgemv must not alias its input and output vectors, hence the copy
		gsl_vector_memcpy(x_prev, x);
		gsl_blas_dgemv(CblasNoTrans, 1.0, W, x_prev, 0.0, x);

		// x += W_fb * feedback (teacher output while forcing, the network's own y afterwards)
		if(useFeedback){
			const gsl_vector *fbSrc = y;
			if(i < forcing_steps){
				gsl_matrix_get_col(fb, desired, i);
				fbSrc = fb;
			}
			gsl_blas_dgemv(CblasNoTrans, 1.0, W_fb, fbSrc, 1.0, x);
		}

		// x += W_in * input[:, i]
		if(nInputs > 0){
			gsl_matrix_get_col(u, input, i);
			gsl_blas_dgemv(CblasNoTrans, 1.0, W_in, u, 1.0, x);
		}
		//activationFunction(x);
		gsl_blas_dgemv(CblasNoTrans, 1.0, W_out, x, 0.0, y);
	}

	gsl_vector_free(x_prev);
	if(fb != NULL) gsl_vector_free(fb);
	if(u != NULL) gsl_vector_free(u);

	print_matrix_to_file(states, "./lang/states.txt");
	gsl_matrix_free(states);

	return output;
}

// Offline (batch) training of the readout W_out.
//
// Drives the reservoir with `input` and teacher-forced feedback from `desired` for
// desired->size2 steps. From step `stateCollectingOfsset` on, it collects the state x
// (taken before the update) and the matching desired column. It then sets
//     W_out = target * pinv(states)
// computed as target * transpose(pinv(transpose(states))).
//
// input  - nInputs x (>= desired->size2), one sample per column; may be NULL if nInputs == 0.
// desired - nOutputs x T teacher outputs, one per column.
// stateCollectingOfsset - number of initial (washout) steps whose states are discarded;
//                         must be smaller than desired->size2.
//
// The members x and y are advanced in place. On invalid arguments a warning is printed
// and W_out is left untouched.
void esn_network::train_offline(gsl_matrix *input, gsl_matrix *desired, unsigned int stateCollectingOfsset){
	const unsigned int offset = stateCollectingOfsset; // delay before states are collected
	// input and desired store samples in columns
	if((input == NULL) && (nInputs > 0)){
		printf("WARNING: Prazdny vstup.");
		return;
	}
	if((input != NULL) && (input->size1 != nInputs)){
		printf("WARNING: Neplatny rozmer vstupu.");
		return;
	}
	if(desired == NULL){
		printf("WARNING: Prazdny vystup.\n");
		return;
	}
	if(desired->size1 != nOutputs){
		printf("WARNING: Neplatny rozmer vstupu.");
		return;
	}
	const size_t nSteps = desired->size2;
	// input column i is read at every step i < nSteps
	if((nInputs > 0) && (input->size2 < nSteps)){
		printf("WARNING: Neplatny rozmer vstupu.");
		return;
	}
	// Otherwise nSteps - offset wraps around (unsigned) or is zero.
	if(offset >= nSteps){
		printf("WARNING: Posun zbierania stavov je mimo rozsahu dat.\n");
		return;
	}

	const size_t nCollected = nSteps - offset;
	gsl_matrix *target = gsl_matrix_calloc(nOutputs, nCollected);
	gsl_matrix *states = gsl_matrix_calloc(nResUnits, nCollected);

	// Scratch vectors are allocated once instead of on every time step.
	gsl_vector *x_prev = gsl_vector_alloc(nResUnits);
	gsl_vector *d = gsl_vector_alloc(desired->size1);
	gsl_vector *u = (nInputs > 0) ? gsl_vector_alloc(nInputs) : NULL;

	for(size_t i = 0; i < nSteps; i++){
		// store the current state and its teacher output once the washout is over
		if(i >= offset){
			gsl_matrix_set_col(states, i - offset, x);
			gsl_matrix_get_col(d, desired, i);
			gsl_matrix_set_col(target, i - offset, d);
		}

		// x <- W * x ; BLAS dgemv must not alias its input and output vectors, hence the copy
		gsl_vector_memcpy(x_prev, x);
		gsl_blas_dgemv(CblasNoTrans, 1.0, W, x_prev, 0.0, x);

		// teacher-forced feedback: x += W_fb * desired[:, i]
		if(useFeedback){
			gsl_matrix_get_col(d, desired, i);
			gsl_blas_dgemv(CblasNoTrans, 1.0, W_fb, d, 1.0, x);
		}

		// x += W_in * input[:, i]
		if(nInputs > 0){
			gsl_matrix_get_col(u, input, i);
			gsl_blas_dgemv(CblasNoTrans, 1.0, W_in, u, 1.0, x);
		}
		//activationFunction(x);
		gsl_blas_dgemv(CblasNoTrans, 1.0, W_out, x, 0.0, y);
	}

	gsl_vector_free(x_prev);
	gsl_vector_free(d);
	if(u != NULL) gsl_vector_free(u);

	// pinv is applied to the tall (samples x units) transpose. Presumably this is because
	// pinv is SVD-based and expects rows >= columns (TODO confirm). Transposing the
	// result back gives pinv(states), since pinv(A^T) = pinv(A)^T.
	gsl_matrix *statesT = gsl_matrix_alloc(states->size2, states->size1);
	gsl_matrix_transpose_memcpy(statesT, states);
	gsl_matrix_free(states);

	gsl_matrix *statesI = pinv(statesT);
	// Defensive: how pinv reports failure is not visible here.
	if(statesI == NULL){
		printf("WARNING: Pseudoinverzia zlyhala.\n");
		gsl_matrix_free(statesT);
		gsl_matrix_free(target);
		return;
	}
	gsl_matrix_transpose_memcpy(statesT, statesI);
	gsl_matrix_free(statesI);

	gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, target, statesT, 0.0, W_out);
	gsl_matrix_free(statesT);
	gsl_matrix_free(target);
}

// Offline (batch) training of the readout W_out with strided state collection.
//
// Drives the reservoir with `input` and teacher-forced feedback from `desired` for
// desired->size2 steps. It collects the state x (taken before the update) and the
// matching desired column at steps start, start + each, start + 2*each, ... where
// start = stateCollectStartAt and each = stateCollectEach. It then sets
//     W_out = target * pinv(states)
// computed as target * transpose(pinv(transpose(states))).
//
// input  - nInputs x (>= desired->size2), one sample per column; may be NULL if nInputs == 0.
// desired - nOutputs x T teacher outputs, one per column.
// stateCollectEach    - collection stride; must be positive.
// stateCollectStartAt - first collected step (washout length); must be < desired->size2.
//
// The members x and y are advanced in place. On invalid arguments a warning is printed
// and W_out is left untouched.
void esn_network::train_offline(gsl_matrix *input, gsl_matrix *desired, unsigned int stateCollectEach, unsigned int stateCollectStartAt){
	// input and desired store samples in columns
	if((input == NULL) && (nInputs > 0)){
		printf("WARNING: Prazdny vstup.");
		return;
	}
	if((input != NULL) && (input->size1 != nInputs)){
		printf("WARNING: Neplatny rozmer vstupu.");
		return;
	}
	if(desired == NULL){
		printf("WARNING: Prazdny vystup.\n");
		return;
	}
	if(desired->size1 != nOutputs){
		printf("WARNING: Neplatny rozmer vstupu.");
		return;
	}
	const size_t nSteps = desired->size2;
	// input column i is read at every step i < nSteps
	if((nInputs > 0) && (input->size2 < nSteps)){
		printf("WARNING: Neplatny rozmer vstupu.");
		return;
	}
	// A zero stride would divide by zero below.
	if(stateCollectEach == 0){
		printf("WARNING: Krok zbierania stavov musi byt kladny.\n");
		return;
	}
	// Otherwise nSteps - start wraps around (unsigned) or nothing is collected.
	if(stateCollectStartAt >= nSteps){
		printf("WARNING: Posun zbierania stavov je mimo rozsahu dat.\n");
		return;
	}

	// ceil((nSteps - start) / each): the number of indices start + k*each below nSteps
	const size_t storedStateCount = (nSteps - stateCollectStartAt + stateCollectEach - 1) / stateCollectEach;
	gsl_matrix *target = gsl_matrix_calloc(nOutputs, storedStateCount);
	gsl_matrix *states = gsl_matrix_calloc(nResUnits, storedStateCount);

	// Scratch vectors are allocated once instead of on every time step.
	gsl_vector *x_prev = gsl_vector_alloc(nResUnits);
	gsl_vector *d = gsl_vector_alloc(desired->size1);
	gsl_vector *u = (nInputs > 0) ? gsl_vector_alloc(nInputs) : NULL;

	size_t storeAt = 0;
	for(size_t i = 0; i < nSteps; i++){
		// The explicit i >= start test matters. A signed (i - start) % each == 0 is also
		// true for washout steps i < start whenever start - i is a multiple of each.
		if((i >= stateCollectStartAt) && (((i - stateCollectStartAt) % stateCollectEach) == 0)){
			gsl_matrix_set_col(states, storeAt, x);
			gsl_matrix_get_col(d, desired, i);
			gsl_matrix_set_col(target, storeAt, d);
			storeAt++;
		}

		// x <- W * x ; BLAS dgemv must not alias its input and output vectors, hence the copy
		gsl_vector_memcpy(x_prev, x);
		gsl_blas_dgemv(CblasNoTrans, 1.0, W, x_prev, 0.0, x);

		// teacher-forced feedback: x += W_fb * desired[:, i]
		if(useFeedback){
			gsl_matrix_get_col(d, desired, i);
			gsl_blas_dgemv(CblasNoTrans, 1.0, W_fb, d, 1.0, x);
		}

		// x += W_in * input[:, i]
		if(nInputs > 0){
			gsl_matrix_get_col(u, input, i);
			gsl_blas_dgemv(CblasNoTrans, 1.0, W_in, u, 1.0, x);
		}
		//activationFunction(x);
		gsl_blas_dgemv(CblasNoTrans, 1.0, W_out, x, 0.0, y);
	}

	gsl_vector_free(x_prev);
	gsl_vector_free(d);
	if(u != NULL) gsl_vector_free(u);

	// pinv is applied to the tall (samples x units) transpose. Presumably this is because
	// pinv is SVD-based and expects rows >= columns (TODO confirm). Transposing the
	// result back gives pinv(states), since pinv(A^T) = pinv(A)^T.
	gsl_matrix *statesT = gsl_matrix_alloc(states->size2, states->size1);
	gsl_matrix_transpose_memcpy(statesT, states);
	gsl_matrix_free(states);

	gsl_matrix *statesI = pinv(statesT);
	// Defensive: how pinv reports failure is not visible here.
	if(statesI == NULL){
		printf("WARNING: Pseudoinverzia zlyhala.\n");
		gsl_matrix_free(statesT);
		gsl_matrix_free(target);
		return;
	}
	gsl_matrix_transpose_memcpy(statesT, statesI);
	gsl_matrix_free(statesI);

	gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, target, statesT, 0.0, W_out);
	gsl_matrix_free(statesT);
	gsl_matrix_free(target);
}


// Accessor for the reservoir weight matrix W; hands out the member pointer itself, not a copy.
gsl_matrix *esn_network::get_matrix_W(){
	return this->W;
}
// Accessor for the input weight matrix W_in; hands out the member pointer itself, not a copy.
gsl_matrix *esn_network::get_matrix_Win(){
	return this->W_in;
}
// Accessor for the readout weight matrix W_out; hands out the member pointer itself, not a copy.
gsl_matrix *esn_network::get_matrix_Wout(){
	return this->W_out;
}
// Accessor for the feedback weight matrix W_fb; hands out the member pointer itself, not a copy.
gsl_matrix *esn_network::get_matrix_Wfb(){
	return this->W_fb;
}
// Dumps the reservoir weight matrix W to stdout: nResUnits lines, each holding
// nResUnits entries printed with an explicit sign and 4 decimals, each followed by a space.
void esn_network::print_reservoir(){
	int row = 0;
	while(row < nResUnits){
		for(int col = 0; col < nResUnits; ++col){
			const double w = gsl_matrix_get(W, row, col);
			printf("%+.4f ", w);
		}
		printf("\n");
		++row;
	}
}

// Builds a random sparse nResUnits x nResUnits reservoir weight matrix.
//
// nResUnits    - matrix dimension.
// connectivity - fraction of entries that become non-zero; clamped into [0, 1].
// resScaling   - value that getAbsMaxEig(result) should have after scaling
//                (presumably the desired spectral radius — confirm against getAbsMaxEig).
//
// Exactly ceil(nResUnits^2 * connectivity) distinct entries are set to rand()-based
// values in [-0.5, 0.5]. The matrix is then multiplied by resScaling / getAbsMaxEig.
// If getAbsMaxEig returns 0 (e.g. an all-zero matrix), the scaling is skipped, because
// dividing by zero would fill the matrix with NaN.
// Returns a newly allocated matrix; the caller frees it.
gsl_matrix *esn_network::create_reservoir(int nResUnits, double connectivity, double resScaling){

	gsl_matrix *result = gsl_matrix_calloc(nResUnits, nResUnits);

	// Clamp into [0, 1]. A negative (or NaN) connectivity used to produce a negative
	// counter, so the rejection loop below never terminated.
	if(connectivity > 1) connectivity = 1;
	if(!(connectivity > 0)) connectivity = 0;

	// Computed in double to avoid int overflow of nResUnits * nResUnits.
	const double cellCount = (double)nResUnits * (double)nResUnits;
	long long nNotZeros = (long long)ceil(cellCount * connectivity);
	if(nNotZeros > (long long)cellCount) nNotZeros = (long long)cellCount;

	// Rejection sampling: pick random cells until enough distinct ones are filled.
	while(nNotZeros > 0){
		int i = rand() % nResUnits;
		int j = rand() % nResUnits;
		if(gsl_matrix_get(result, i, j) == 0){
			double value = (rand() / (double)RAND_MAX) - 0.5;
			gsl_matrix_set(result, i, j, value);
			nNotZeros--;
		}
	}

	double scale_factor = getAbsMaxEig(result);
	if(scale_factor != 0){
		gsl_matrix_scale(result, resScaling / scale_factor);
	}

	return result;
}

// Writes the network weight matrices as text files into ./esn/.
// W_out.txt is written with gsl_matrix_fprintf using "%e"; W_in.txt, W.txt and
// W_fb.txt are written with print_matrix_to_file. The directory is created with
// mkdir first. Its result is not checked, because the directory may already exist.
void esn_network::saveNetwork(){
	// const: binding a string literal to a non-const char* is ill-formed since C++11
	const char *path = "./esn/";
	mkdir(path);

	const std::string woutFile = std::string(path) + "W_out.txt";
	FILE *f = fopen(woutFile.c_str(), "w");
	if(f == NULL){
		// Previously a NULL stream was passed to gsl_matrix_fprintf/fclose.
		printf("WARNING: Nepodarilo sa otvorit subor %s.\n", woutFile.c_str());
	}else{
		gsl_matrix_fprintf(f, W_out, "%e");
		fclose(f);
	}

	print_matrix_to_file(W_in, "./esn/W_in.txt");
	print_matrix_to_file(W, "./esn/W.txt");
	print_matrix_to_file(W_fb, "./esn/W_fb.txt");
}

void esn_network::loadNetwork(){
		
		char *path = "./esn/";
		stringstream filename;
		filename << path; 
		filename <<"W_out";
		filename << ".txt";

		char *str = (char*)malloc(sizeof(char) * filename.str().length() + 1);
		strcpy(str, filename.str().c_str());

		FILE *f = fopen(str, "r");
		gsl_matrix_fscanf(f, W_out);
		fclose(f);

		f = fopen("./esn/W_in.txt", "r");
		gsl_matrix_fscanf(f, W_in);
		fclose(f);

		f = fopen("./esn/W.txt", "r");
		gsl_matrix_fscanf(f, W);
		fclose(f);

		f = fopen("./esn/W_fb.txt", "r");
		gsl_matrix_fscanf(f, W_fb);
		fclose(f);

		free(str);
}